# 这个文件是关于 pytorch 的一些测试
import torch
from torch.nn.functional import softmax

if __name__ == '__main__':
    # Demonstrate Tensor.view: the six elements of a 1-D tensor are
    # reinterpreted as a 2x3 matrix (row-major), sharing the same data.
    flat = torch.tensor([1, 2, 3, 4, 5, 6])
    grid = flat.view(2, 3)
    for shown in (flat, grid):
        print(shown)
    # Expected output:
    # tensor([1, 2, 3, 4, 5, 6])
    # tensor([[1, 2, 3],
    #         [4, 5, 6]])


if __name__ == '__main__':
    # 24 evenly spaced values from 1 to 4 (inclusive), laid out as a 2x3x4 tensor.
    cube = torch.linspace(1, 4, 24).view(2, 3, 4)
    print(cube)
    # tensor([[[1.0000, 1.1304, 1.2609, 1.3913],
    #          [1.5217, 1.6522, 1.7826, 1.9130],
    #          [2.0435, 2.1739, 2.3043, 2.4348]],
    #
    #         [[2.5652, 2.6957, 2.8261, 2.9565],
    #          [3.0870, 3.2174, 3.3478, 3.4783],
    #          [3.6087, 3.7391, 3.8696, 4.0000]]])

    # Apply softmax along dim 0, then 1, then 2. Because the input is an
    # arithmetic sequence, neighbours along a given dim always differ by the
    # same amount (12, 4 and 1 steps of 3/23 respectively). Softmax only
    # depends on those differences, so each result repeats across the other axes.
    for dim in range(cube.dim()):
        print(softmax(cube, dim=dim))

    # dim=0:
    # tensor([[[0.1729, 0.1729, 0.1729, 0.1729],
    #          [0.1729, 0.1729, 0.1729, 0.1729],
    #          [0.1729, 0.1729, 0.1729, 0.1729]],
    #
    #         [[0.8271, 0.8271, 0.8271, 0.8271],
    #          [0.8271, 0.8271, 0.8271, 0.8271],
    #          [0.8271, 0.8271, 0.8271, 0.8271]]])
    # dim=1:
    # tensor([[[0.1810, 0.1810, 0.1810, 0.1810],
    #          [0.3050, 0.3050, 0.3050, 0.3050],
    #          [0.5139, 0.5139, 0.5139, 0.5139]],
    #
    #         [[0.1810, 0.1810, 0.1810, 0.1810],
    #          [0.3050, 0.3050, 0.3050, 0.3050],
    #          [0.5139, 0.5139, 0.5139, 0.5139]]])
    # dim=2:
    # tensor([[[0.2034, 0.2317, 0.2640, 0.3008],
    #          [0.2034, 0.2317, 0.2640, 0.3008],
    #          [0.2034, 0.2317, 0.2640, 0.3008]],
    #
    #         [[0.2034, 0.2317, 0.2640, 0.3008],
    #          [0.2034, 0.2317, 0.2640, 0.3008],
    #          [0.2034, 0.2317, 0.2640, 0.3008]]])

